"""
Prepare the given data for use in the Trainer class (train.py):
tokenize the segments and encode their tokens as ids, pad the
padding_token_ids to a fixed length, create the attention masks, split the
data into training and validation sets, and provide some small conversion
functions (train_valid_to_tensor, convert_labels_to_int).
Finally, prepare the training and test (validation) data and the
corresponding DataLoader.
"""

import numpy as np
from tqdm import tqdm

from config import (
    get_tokenizer,
    MAX_LEN
)


class BertPreprocessor():
    """
    Preprocess text segments for a BERT model: tokenize them to ids and
    build the corresponding attention masks.
    """
    def __init__(self):
        # The label ids this preprocessor works with (0, 1 and 2); their
        # class meanings are not defined in this method.
        self.labels = list(range(3))
        # Tokenizer obtained from the project config (presumably a BERT
        # tokenizer, per the class name — confirm in config.get_tokenizer).
        self.tokenizer = get_tokenizer()

    def tokenize_segments_to_id(
            self, to_be_tokenized_segments: np.ndarray) -> list:
        """
        Tokenize every segment and map its tokens to vocabulary ids.

        Each segment is passed to ``self.tokenizer.encode`` with
        ``add_special_tokens=True`` (for BERT this adds '[CLS]' and '[SEP]').
        No padding or truncation is requested here, so the resulting id
        sequences may differ in length.

        :param to_be_tokenized_segments: iterable of segments, e.g. a numpy
            array, list or pandas Series of strings. A single string is
            treated as one segment.
        :return: list holding one token-id sequence per segment, in input
            order.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(to_be_tokenized_segments, str):
            to_be_tokenized_segments = [to_be_tokenized_segments]
        # Iterate the input directly: copying it into an object ndarray is
        # unnecessary, fails for generators (0-d array) and turns
        # equal-length nested inputs into a 2-D array.
        return [
            self.tokenizer.encode(segment, add_special_tokens=True)
            for segment in to_be_tokenized_segments
        ]

    def create_attention_masks(self, padding_token_ids: list) -> list:
        """
        Build one attention mask per token-id sequence.

        A position is marked 1 when its token id is greater than 0 (a real
        token) and 0 otherwise (treated as padding, e.g. id 0).

        :param padding_token_ids: iterable of token-id sequences.
        :return: list of 0/1 int lists, parallel to ``padding_token_ids``.
        """
        return [
            [int(tok > 0) for tok in ids]
            for ids in padding_token_ids
        ]

    def preprocess(self, slim=True, segments=False, maxlen=MAX_LEN):
        """
        Run the preprocessing steps in one call: tokenize the segments to
        ids and build their attention masks.

        Side effect: stores the segments being processed in
        ``self.segments``.

        :param slim: when truthy, return
            ``(input_ids, padding_token_ids, attention_masks)``.
            NOTE(review): the falsy ``slim`` path is not visible in this
            method body — confirm whether more code follows; otherwise the
            method returns None in that case.
        :param segments: segments to process. Only the bool ``False`` (the
            default) means "use ``self.data_frame.segment.values``"; any
            other value (including ``None`` or an empty list) is stored and
            tokenized as-is.
        :param maxlen: NOTE(review): currently unused — no padding or
            truncation to ``maxlen`` happens here.
        :return: see ``slim``.
        """
        # The ``isinstance`` check ensures only a literal False (not None,
        # 0 or an empty container) selects the data frame's segments.
        # NOTE(review): ``self.data_frame`` is not set by ``__init__``;
        # presumably it is assigned by a caller or subclass — verify.
        if isinstance(segments, bool) and not segments:
            self.segments = self.data_frame.segment.values
        else:
            self.segments = segments

        input_ids = self.tokenize_segments_to_id(self.segments)
        # NOTE(review): despite the name, no padding is applied — this is the
        # same list object as ``input_ids``, so sequences keep their
        # tokenized lengths and masks only reflect ids <= 0 in the raw output.
        padding_token_ids = input_ids
        attention_masks = self.create_attention_masks(padding_token_ids)

        if slim:
            return input_ids, padding_token_ids, attention_masks
